package cn.uncode.rpc.spring;


import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.apache.commons.lang3.StringUtils;
import org.springframework.aop.support.AopUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.BeanInitializationException;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;

import cn.uncode.rpc.common.log.Logger;
import cn.uncode.rpc.common.log.LoggerFactory;
import cn.uncode.rpc.config.BasicCallerConfig;
import cn.uncode.rpc.config.BasicProviderConfig;
import cn.uncode.rpc.config.ExtConfig;
import cn.uncode.rpc.config.ProtocolConfig;
import cn.uncode.rpc.config.RegistryConfig;
import cn.uncode.rpc.spring.annotation.Caller;
import cn.uncode.rpc.spring.annotation.Provider;
import cn.uncode.rpc.util.ConcurrentHashSet;
import cn.uncode.rpc.util.ConfigUtil;
import cn.uncode.rpc.util.SpringBeanUtil;



/**
 * Spring integration bean that wires uncode RPC annotations into the container.
 *
 * <p>Responsibilities:
 * <ul>
 *   <li>as a {@link BeanFactoryPostProcessor}, optionally scans the configured packages for classes
 *       annotated with {@link Provider} and registers them as bean definitions;</li>
 *   <li>before a bean is initialized, injects remote caller proxies into public setters and declared
 *       fields annotated with {@link Caller};</li>
 *   <li>after a bean is initialized, builds a {@link ProviderConfigBean} for a bean annotated with
 *       {@link Provider} and exports it.</li>
 * </ul>
 * Exported providers and created caller configs are released in {@link #destroy()}.
 */
public class AnnotationBean implements DisposableBean, BeanFactoryPostProcessor, BeanPostProcessor, BeanFactoryAware {

    private static final Logger LOGGER = LoggerFactory.getLogger(AnnotationBean.class);

    private String id;

    /** Raw package list as configured through {@link #setPackage(String)}. */
    private String annotationPackage;

    /** {@link #annotationPackage} split on commas; {@code null} when no package is configured. */
    private String[] annotationPackages;

    private BeanFactory beanFactory;

    /** Providers exported by this bean; unexported in {@link #destroy()}. */
    private final Set<ProviderConfigBean<?>> providerConfigs = new ConcurrentHashSet<ProviderConfigBean<?>>();

    /** Caller configs keyed by {@code group/interfaceName:version}, shared by all matching injection points. */
    private final ConcurrentMap<String, CallerConfigBean<?>> callerConfigs =
            new ConcurrentHashMap<String, CallerConfigBean<?>>();

    public AnnotationBean() {}

    /**
     * Registers bean definitions for {@link Provider}-annotated classes under the configured packages.
     *
     * <p>The Spring classpath scanner is loaded reflectively so this class still works on Spring versions
     * that do not ship it (Spring 2.0); in that case scanning is silently skipped. Any other scanning
     * failure is logged and scanning is abandoned (best effort, as before).
     *
     * @param beanFactory the factory being post-processed; only scanned when it is a {@link BeanDefinitionRegistry}
     * @throws BeansException declared by the interface; not thrown by this implementation
     */
    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
        if (annotationPackage == null || annotationPackage.length() == 0) {
            return;
        }
        if (!(beanFactory instanceof BeanDefinitionRegistry)) {
            return;
        }
        ClassLoader classLoader = AnnotationBean.class.getClassLoader();
        try {
            Class<?> scannerClass = ClassUtils.forName(
                    "org.springframework.context.annotation.ClassPathBeanDefinitionScanner", classLoader);
            Object scanner = scannerClass.getConstructor(BeanDefinitionRegistry.class, boolean.class)
                    .newInstance(beanFactory, true);
            // only pick up classes carrying @Provider
            Class<?> filterClass = ClassUtils.forName(
                    "org.springframework.core.type.filter.AnnotationTypeFilter", classLoader);
            Object filter = filterClass.getConstructor(Class.class).newInstance(Provider.class);
            Class<?> typeFilterClass = ClassUtils.forName(
                    "org.springframework.core.type.filter.TypeFilter", classLoader);
            scannerClass.getMethod("addIncludeFilter", typeFilterClass).invoke(scanner, filter);
            // cast to Object so the array is passed as the single varargs argument of scan(String...)
            scannerClass.getMethod("scan", String[].class).invoke(scanner, (Object) annotationPackages);
        } catch (ClassNotFoundException e) {
            // Spring 2.0: the annotation scanner classes do not exist, annotation scanning is unsupported.
        } catch (Throwable e) {
            LOGGER.error("Failed to scan packages " + annotationPackage + " for @Provider classes: "
                    + e.getMessage(), e);
        }
    }

    /**
     * Injects remote caller proxies into {@link Caller}-annotated public setters and declared fields.
     *
     * @param bean the bean instance being initialized
     * @param beanName the name of the bean
     * @return the same bean instance
     * @throws BeansException a {@link BeanInitializationException} when a caller proxy cannot be created or injected
     */
    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        if (!isMatchPackage(bean)) {
            return bean;
        }
        Class<?> clazz = resolveTargetClass(bean);
        injectCallerSetters(bean, clazz);
        injectCallerFields(bean, clazz);
        return bean;
    }

    /** Invokes each {@link Caller}-annotated public, non-static, single-argument setter with a caller proxy. */
    private void injectCallerSetters(Object bean, Class<?> clazz) {
        for (Method method : clazz.getMethods()) {
            if (!isPlainSetter(method)) {
                continue;
            }
            String name = method.getName();
            try {
                Caller caller = method.getAnnotation(Caller.class);
                if (caller == null) {
                    continue;
                }
                Object value = refer(caller, method.getParameterTypes()[0]);
                if (value != null) {
                    method.invoke(bean, value);
                }
            } catch (Exception e) {
                throw new BeanInitializationException("Failed to init remote service reference at method " + name
                        + " in class " + bean.getClass().getName(), e);
            }
        }
    }

    /** Writes a caller proxy into each {@link Caller}-annotated field declared directly on {@code clazz}. */
    private void injectCallerFields(Object bean, Class<?> clazz) {
        for (Field field : clazz.getDeclaredFields()) {
            try {
                Caller caller = field.getAnnotation(Caller.class);
                if (caller == null) {
                    continue;
                }
                Object value = refer(caller, field.getType());
                if (value != null) {
                    // only annotated fields are opened up; the value is written exactly once
                    ReflectionUtils.makeAccessible(field);
                    ReflectionUtils.setField(field, bean, value);
                }
            } catch (Exception e) {
                throw new BeanInitializationException("Failed to init remote service reference at field "
                        + field.getName() + " in class " + bean.getClass().getName(), e);
            }
        }
    }

    /** Returns true for a public, non-static method named {@code setXxx} that takes exactly one argument. */
    private static boolean isPlainSetter(Method method) {
        String name = method.getName();
        int modifiers = method.getModifiers();
        return name.length() > 3 && name.startsWith("set")
                && method.getParameterTypes().length == 1
                && Modifier.isPublic(modifiers)
                && !Modifier.isStatic(modifiers);
    }

    /**
     * Builds and exports a provider config for a bean whose class is annotated with {@link Provider}.
     *
     * <p>Annotation attributes are only copied when a bean factory is available; in that case the
     * config's {@code afterPropertiesSet()} is also invoked before export.
     *
     * @param bean the initialized bean instance, used as the service implementation
     * @param beanName the name of the bean
     * @return the same bean instance
     * @throws BeansException declared by the interface
     * @throws IllegalStateException when no service interface can be determined, or config initialization fails
     *         with a checked exception
     */
    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        if (!isMatchPackage(bean)) {
            return bean;
        }
        Class<?> clazz = resolveTargetClass(bean);
        Provider provider = clazz.getAnnotation(Provider.class);
        if (provider == null) {
            return bean;
        }
        ProviderConfigBean<Object> providerConfig = new ProviderConfigBean<Object>();
        providerConfig.setInterface(resolveProviderInterface(clazz, provider));
        if (beanFactory != null) {
            providerConfig.setBeanFactory(beanFactory);
            configureProvider(providerConfig, provider);
            try {
                providerConfig.afterPropertiesSet();
            } catch (RuntimeException e) {
                throw e;
            } catch (Exception e) {
                throw new IllegalStateException(e.getMessage(), e);
            }
        }
        providerConfig.setRef(bean);
        providerConfigs.add(providerConfig);
        providerConfig.export();
        return bean;
    }

    /**
     * Picks the exported interface: the explicit {@code interfaceClass} when set, otherwise the first
     * interface implemented by {@code clazz}.
     */
    @SuppressWarnings("unchecked")
    private static Class<Object> resolveProviderInterface(Class<?> clazz, Provider provider) {
        if (!void.class.equals(provider.interfaceClass())) {
            return (Class<Object>) provider.interfaceClass();
        }
        Class<?>[] interfaces = clazz.getInterfaces();
        if (interfaces.length == 0) {
            throw new IllegalStateException("Failed to export remote service class " + clazz.getName()
                    + ", cause: The @Provider undefined interfaceClass or interfaceName, and the service class unimplemented any interfaces.");
        }
        return (Class<Object>) interfaces[0];
    }

    /** Copies non-empty / enabled {@link Provider} attributes onto the config; requires {@code beanFactory}. */
    private void configureProvider(ProviderConfigBean<Object> providerConfig, Provider provider) {
        if (StringUtils.isNotEmpty(provider.basicProvider())) {
            providerConfig.setBasicServiceConfig(beanFactory.getBean(provider.basicProvider(), BasicProviderConfig.class));
        }
        if (StringUtils.isNotEmpty(provider.export())) {
            providerConfig.setExport(provider.export());
        }
        if (StringUtils.isNotEmpty(provider.host())) {
            providerConfig.setHost(provider.host());
        }

        // an explicit protocol list wins; otherwise protocols are derived from the export spec
        String protocolValue = null;
        if (StringUtils.isNotEmpty(provider.protocol())) {
            protocolValue = provider.protocol();
        } else if (StringUtils.isNotEmpty(provider.export())) {
            protocolValue = ConfigUtil.extractProtocols(provider.export());
        }
        if (!StringUtils.isBlank(protocolValue)) {
            List<ProtocolConfig> protocolConfigs = SpringBeanUtil.getMultiBeans(beanFactory, protocolValue,
                    SpringBeanUtil.COMMA_SPLIT_PATTERN, ProtocolConfig.class);
            providerConfig.setProtocols(protocolConfigs);
        }

        if (StringUtils.isNotEmpty(provider.registry())) {
            List<RegistryConfig> registryConfigs = SpringBeanUtil.getMultiBeans(beanFactory, provider.registry(),
                    SpringBeanUtil.COMMA_SPLIT_PATTERN, RegistryConfig.class);
            providerConfig.setRegistries(registryConfigs);
        }
        if (StringUtils.isNotEmpty(provider.extConfig())) {
            providerConfig.setExtConfig(beanFactory.getBean(provider.extConfig(), ExtConfig.class));
        }
        if (StringUtils.isNotEmpty(provider.application())) {
            providerConfig.setApplication(provider.application());
        }
        if (StringUtils.isNotEmpty(provider.module())) {
            providerConfig.setModule(provider.module());
        }
        if (StringUtils.isNotEmpty(provider.group())) {
            providerConfig.setGroup(provider.group());
        }
        if (StringUtils.isNotEmpty(provider.version())) {
            providerConfig.setVersion(provider.version());
        }
        if (StringUtils.isNotEmpty(provider.proxy())) {
            providerConfig.setProxy(provider.proxy());
        }
        if (StringUtils.isNotEmpty(provider.filter())) {
            providerConfig.setFilter(provider.filter());
        }
        if (provider.actives() > 0) {
            providerConfig.setActives(provider.actives());
        }
        if (provider.async()) {
            providerConfig.setAsync(provider.async());
        }
        if (StringUtils.isNotEmpty(provider.mock())) {
            providerConfig.setMock(provider.mock());
        }
        // boolean flags are only propagated when true; otherwise the config's own default is kept
        if (provider.shareChannel()) {
            providerConfig.setShareChannel(provider.shareChannel());
        }
        if (provider.throwException()) {
            providerConfig.setThrowException(provider.throwException());
        }
        if (provider.requestTimeout() > 0) {
            providerConfig.setRequestTimeout(provider.requestTimeout());
        }
        if (provider.register()) {
            providerConfig.setRegister(provider.register());
        }
        if (provider.accessLog()) {
            providerConfig.setAccessLog("true");
        }
        if (provider.check()) {
            providerConfig.setCheck("true");
        }
        if (provider.usegz()) {
            providerConfig.setUsegz(provider.usegz());
        }
        if (provider.retries() > 0) {
            providerConfig.setRetries(provider.retries());
        }
        if (provider.mingzSize() > 0) {
            providerConfig.setMingzSize(provider.mingzSize());
        }
        if (StringUtils.isNotEmpty(provider.codec())) {
            providerConfig.setCodec(provider.codec());
        }
    }

    /**
     * Unexports every provider and destroys every caller config created by this bean.
     * Individual failures are logged so that the remaining resources are still released.
     *
     * @throws Exception declared by {@link DisposableBean}; not thrown by this implementation
     */
    @Override
    public void destroy() throws Exception {
        for (ProviderConfigBean<?> providerConfig : providerConfigs) {
            try {
                providerConfig.unexport();
            } catch (Throwable e) {
                LOGGER.error(e.getMessage(), e);
            }
        }
        for (CallerConfigBean<?> callerConfig : callerConfigs.values()) {
            try {
                callerConfig.destroy();
            } catch (Throwable e) {
                LOGGER.error(e.getMessage(), e);
            }
        }
    }

    /**
     * Returns the remote proxy for an injection point annotated with {@link Caller}.
     *
     * @param caller the annotation on the setter or field
     * @param callerClass the declared type of the injection point
     * @param <T> the service type
     * @return the proxy obtained from the (possibly shared) caller config
     */
    private <T> Object refer(Caller caller, Class<T> callerClass) {
        return getOrCreateCallerConfig(caller, callerClass).getRef();
    }

    /**
     * Returns the caller config for {@code group/interfaceName:version}, creating and initializing it on first use.
     *
     * <p>Configs are cached per key, so later injection points with the same key reuse the first config
     * regardless of their other attributes.
     *
     * @param caller the annotation on the injection point
     * @param callerClass the declared type of the injection point
     * @param <T> the service type
     * @return the cached or newly created config
     * @throws IllegalStateException when no interface can be determined, or config initialization fails
     *         with a checked exception
     */
    @SuppressWarnings("unchecked")
    private <T> CallerConfigBean<T> getOrCreateCallerConfig(Caller caller, Class<T> callerClass) {
        String interfaceName = resolveCallerInterfaceName(caller, callerClass);
        String key = caller.group() + "/" + interfaceName + ":" + caller.version();
        CallerConfigBean<T> referenceConfig = (CallerConfigBean<T>) callerConfigs.get(key);
        if (referenceConfig != null) {
            return referenceConfig;
        }

        referenceConfig = new CallerConfigBean<T>();
        referenceConfig.setBeanFactory(beanFactory);
        if (!void.class.equals(caller.interfaceClass())) {
            referenceConfig.setInterface((Class<T>) caller.interfaceClass());
        } else {
            // resolveCallerInterfaceName guarantees callerClass is an interface here
            referenceConfig.setInterface(callerClass);
        }

        if (beanFactory != null) {
            configureCaller(referenceConfig, caller);
            try {
                referenceConfig.afterPropertiesSet();
            } catch (RuntimeException e) {
                throw e;
            } catch (Exception e) {
                throw new IllegalStateException(e.getMessage(), e);
            }
        }

        // if another thread registered a config for the same key first, keep using that one
        CallerConfigBean<T> existing = (CallerConfigBean<T>) callerConfigs.putIfAbsent(key, referenceConfig);
        return existing != null ? existing : referenceConfig;
    }

    /** Returns the explicit {@code interfaceClass} name, or the injection type's name when it is an interface. */
    private static String resolveCallerInterfaceName(Caller caller, Class<?> callerClass) {
        if (!void.class.equals(caller.interfaceClass())) {
            return caller.interfaceClass().getName();
        }
        if (callerClass.isInterface()) {
            return callerClass.getName();
        }
        throw new IllegalStateException("The @Caller undefined interfaceClass or interfaceName, and the property type "
                + callerClass.getName() + " is not a interface.");
    }

    /** Copies non-empty / enabled {@link Caller} attributes onto the config; requires {@code beanFactory}. */
    private <T> void configureCaller(CallerConfigBean<T> referenceConfig, Caller caller) {
        if (StringUtils.isNotEmpty(caller.protocol())) {
            // comma-separated list of protocol bean names
            List<ProtocolConfig> protocolConfigs = SpringBeanUtil.getMultiBeans(beanFactory, caller.protocol(),
                    SpringBeanUtil.COMMA_SPLIT_PATTERN, ProtocolConfig.class);
            referenceConfig.setProtocols(protocolConfigs);
        }
        if (StringUtils.isNotEmpty(caller.directUrl())) {
            referenceConfig.setDirectUrl(caller.directUrl());
        }
        if (StringUtils.isNotEmpty(caller.basicCaller())) {
            BasicCallerConfig biConfig = beanFactory.getBean(caller.basicCaller(), BasicCallerConfig.class);
            if (biConfig != null) {
                referenceConfig.setBasicReferer(biConfig);
            }
        }
        // TODO: caller.client() and method-level settings are not applied yet
        if (StringUtils.isNotEmpty(caller.registry())) {
            List<RegistryConfig> registryConfigs = SpringBeanUtil.getMultiBeans(beanFactory, caller.registry(),
                    SpringBeanUtil.COMMA_SPLIT_PATTERN, RegistryConfig.class);
            referenceConfig.setRegistries(registryConfigs);
        }
        if (StringUtils.isNotEmpty(caller.extConfig())) {
            referenceConfig.setExtConfig(beanFactory.getBean(caller.extConfig(), ExtConfig.class));
        }
        if (StringUtils.isNotEmpty(caller.application())) {
            referenceConfig.setApplication(caller.application());
        }
        if (StringUtils.isNotEmpty(caller.module())) {
            referenceConfig.setModule(caller.module());
        }
        if (StringUtils.isNotEmpty(caller.group())) {
            referenceConfig.setGroup(caller.group());
        }
        if (StringUtils.isNotEmpty(caller.version())) {
            referenceConfig.setVersion(caller.version());
        }
        if (StringUtils.isNotEmpty(caller.proxy())) {
            referenceConfig.setProxy(caller.proxy());
        }
        if (StringUtils.isNotEmpty(caller.filter())) {
            referenceConfig.setFilter(caller.filter());
        }
        if (caller.actives() > 0) {
            referenceConfig.setActives(caller.actives());
        }
        if (caller.async()) {
            referenceConfig.setAsync(caller.async());
        }
        if (StringUtils.isNotEmpty(caller.mock())) {
            referenceConfig.setMock(caller.mock());
        }
        // boolean flags are only propagated when true; otherwise the config's own default is kept
        if (caller.shareChannel()) {
            referenceConfig.setShareChannel(caller.shareChannel());
        }
        if (caller.throwException()) {
            referenceConfig.setThrowException(caller.throwException());
        }
        if (caller.requestTimeout() > 0) {
            referenceConfig.setRequestTimeout(caller.requestTimeout());
        }
        if (caller.register()) {
            referenceConfig.setRegister(caller.register());
        }
        if (caller.accessLog()) {
            referenceConfig.setAccessLog("true");
        }
        if (caller.check()) {
            referenceConfig.setCheck("true");
        }
        if (caller.retries() > 0) {
            referenceConfig.setRetries(caller.retries());
        }
        if (caller.usegz()) {
            referenceConfig.setUsegz(caller.usegz());
        }
        if (caller.mingzSize() > 0) {
            referenceConfig.setMingzSize(caller.mingzSize());
        }
        if (StringUtils.isNotEmpty(caller.codec())) {
            referenceConfig.setCodec(caller.codec());
        }
        if (StringUtils.isNotEmpty(caller.mean())) {
            referenceConfig.setMean(caller.mean());
        }
        if (StringUtils.isNotEmpty(caller.p90())) {
            referenceConfig.setP90(caller.p90());
        }
        if (StringUtils.isNotEmpty(caller.p99())) {
            referenceConfig.setP99(caller.p99());
        }
        if (StringUtils.isNotEmpty(caller.p999())) {
            referenceConfig.setP999(caller.p999());
        }
        if (StringUtils.isNotEmpty(caller.errorRate())) {
            referenceConfig.setErrorRate(caller.errorRate());
        }
    }

    /**
     * Returns true when no package filter is configured, or when the bean's (unproxied) class name
     * starts with one of the configured package prefixes.
     */
    private boolean isMatchPackage(Object bean) {
        if (annotationPackages == null || annotationPackages.length == 0) {
            return true;
        }
        String beanClassName = resolveTargetClass(bean).getName();
        for (String pkg : annotationPackages) {
            if (beanClassName.startsWith(pkg)) {
                return true;
            }
        }
        return false;
    }

    /** Returns the user class behind an AOP proxy, or the bean's own class when it is not proxied. */
    private Class<?> resolveTargetClass(Object bean) {
        return isProxyBean(bean) ? AopUtils.getTargetClass(bean) : bean.getClass();
    }

    private boolean isProxyBean(Object bean) {
        return AopUtils.isAopProxy(bean);
    }

    public String getPackage() {
        return annotationPackage;
    }

    /**
     * Sets the comma-separated list of packages to scan and to restrict annotation processing to.
     *
     * @param annotationPackage package prefixes separated by {@link SpringBeanUtil#COMMA_SPLIT_PATTERN};
     *        {@code null} or empty disables scanning and lets every bean be processed
     */
    public void setPackage(String annotationPackage) {
        this.annotationPackage = annotationPackage;
        this.annotationPackages = (annotationPackage == null || annotationPackage.length() == 0) ? null
                : annotationPackage.split(SpringBeanUtil.COMMA_SPLIT_PATTERN);
    }

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        this.beanFactory = beanFactory;
    }
}
